#!/usr/bin/env python3
# Copyright (c) 2021 maminjie <canpool@163.com>
# SPDX-License-Identifier: MulanPSL-2.0


import argparse
import os
import sys


def parse_command_line():
    """Parse the command-line options of the patch tool from ``sys.argv``.

    Returns:
        argparse.Namespace: ``command`` is one of ``check``, ``group`` or
        ``info`` (required), ``patch`` is the patch file or directory
        (required) and ``source`` is the source code directory, which is
        ``None`` when the option is not given.
    """
    # Kept as one pre-formatted text: RawTextHelpFormatter prints it verbatim.
    command_help = (
        "The subcommand, as follows:\n"
        "check - Check differences between patch and source code\n"
        "group - Group patches by modified files\n"
        "info  - Show informations of patch file(s)\n"
        "        flags:\n"
        "            - Means to delete file\n"
        "            + Means to add file\n"
        "            * Means to modify file\n"
    )
    parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument("-c", "--command", type=str, required=True,
                        choices=['check', 'group', 'info'], help=command_help)
    parser.add_argument("-p", "--patch", type=str, required=True,
                        help="The patch file or directory")
    parser.add_argument("-s", "--source", type=str,
                        help="The source code directory")
    return parser.parse_args()


def open_patch(patch_file):
    """Open *patch_file* for reading as UTF-8 text.

    Bytes that cannot be decoded are dropped (``errors="ignore"``) instead of
    raising.  The caller owns the returned file object and must close it.
    """
    handle = open(patch_file, mode="r", encoding="utf8", errors="ignore")
    return handle

def match_source_file(file_part, src_dir):
    """Find a file under *src_dir* whose path ends with *file_part*.

    The comparison is anchored on a path-component boundary, so ``foo.c``
    matches ``.../foo.c`` but not ``.../barfoo.c``, and ``sub/foo.c`` does not
    match ``.../mysub/foo.c``.

    Args:
        file_part: Trailing part of the wanted path, written with ``/``
            separators (as paths appear in patch headers).
        src_dir: Directory that is walked recursively.

    Returns:
        The path (``src_dir`` joined with the relative location) of the first
        matching file in ``os.walk`` order, or ``""`` when nothing matches or
        *file_part* is empty.
    """
    wanted = file_part.lstrip("/")
    if not wanted:
        # An empty suffix would otherwise "match" whatever file comes first.
        return ""
    anchored = "/" + wanted
    for path, _, files in os.walk(src_dir):
        for file in files:
            file_path = os.path.join(path, file)
            # Compare with '/' separators whatever the platform uses.
            normalized = file_path.replace(os.sep, "/")
            if normalized == wanted or normalized.endswith(anchored):
                return file_path
    return ""


def check_patch_file(patch_file, src_dir):
    """Count how many changed lines of a patch are already reflected in a source tree.

    The patch is scanned for a ``--- old`` line immediately followed by a
    ``+++ new`` line.  The first path component of the relevant file name is
    dropped (when the name contains a ``/``) and the rest is looked up under
    *src_dir* with ``match_source_file``.  The lines that follow are then
    examined up to the next line starting with ``"--- "``:

    * added file (old side is ``/dev/null``): every ``+`` line is counted,
      and counts as matched when the file exists in the source tree;
    * deleted file (new side is ``/dev/null``): every ``-`` line that does not
      start with ``--`` is counted, and counts as matched when the file does
      NOT exist in the source tree;
    * modified file: every ``+`` line is counted, and counts as matched when
      an identical line is found by reading the source file forward from
      where the previous match ended.  When the end of the source file is
      reached without a match, the next added line is searched from the top.

    Note that any line starting with ``"--- "`` ends the current section, so
    a removed line whose own text starts with ``"-- "`` cuts the section short.

    Args:
        patch_file: Path of the patch file.
        src_dir: Root directory of the source tree.

    Returns:
        tuple: ``(patch_cnt, match_cnt)``, or ``(-1, 0)`` when the patch file
        cannot be opened.
    """
    try:
        patch_file_hdl = open_patch(patch_file)
    except OSError:
        # open() raises instead of returning a falsy handle; translate the
        # failure into the sentinel the report code checks for.
        return (-1, 0)

    def strip_first_component(name):
        # "a/src/x.c" -> "src/x.c"; names without '/' are left untouched.
        if '/' in name:
            return '/'.join(name.split('/')[1:])
        return name

    patch_cnt = 0
    match_cnt = 0
    with patch_file_hdl:
        line = patch_file_hdl.readline()
        while line:
            if not (line.startswith("--- ") and len(line.split()) > 1):
                line = patch_file_hdl.readline()
                continue
            old_file = line.split()[1]
            line = patch_file_hdl.readline()
            if not line:
                break
            if not (line.startswith("+++ ") and len(line.split()) > 1):
                # Not a header pair.  Re-examine this line instead of skipping
                # it: it may itself be the "--- " line of a real header.
                continue
            new_file = line.split()[1]

            # Every branch below stops with `line` either empty (EOF) or
            # starting with "--- ", which the outer loop then re-examines.
            if old_file == "/dev/null":  # add file
                source_file = match_source_file(strip_first_component(new_file), src_dir)
                line = patch_file_hdl.readline()
                while line and not line.startswith("--- "):
                    if line.startswith("+"):
                        patch_cnt += 1
                        if source_file:
                            match_cnt += 1
                    line = patch_file_hdl.readline()
            elif new_file == "/dev/null":  # delete file
                source_file = match_source_file(strip_first_component(old_file), src_dir)
                line = patch_file_hdl.readline()
                while line and not line.startswith("--- "):
                    if line.startswith("-") and not line.startswith("--"):
                        patch_cnt += 1
                        # A deletion is "applied" when the file is gone.
                        if not source_file:
                            match_cnt += 1
                    line = patch_file_hdl.readline()
            else:  # modify file
                source_file = match_source_file(strip_first_component(new_file), src_dir)
                src_file_hdl = None
                try:
                    line = patch_file_hdl.readline()
                    while line and not line.startswith("--- "):
                        if line.startswith("+"):
                            patch_cnt += 1
                            if source_file:
                                if src_file_hdl is None:
                                    src_file_hdl = open_patch(source_file)
                                wanted = line[1:]
                                # Forward-only search from the current position.
                                src_line = src_file_hdl.readline()
                                while src_line and src_line != wanted:
                                    src_line = src_file_hdl.readline()
                                if src_line:
                                    match_cnt += 1
                                else:
                                    # Hit EOF: reopen lazily so the next added
                                    # line is searched from the top again.
                                    src_file_hdl.close()
                                    src_file_hdl = None
                        line = patch_file_hdl.readline()
                finally:
                    if src_file_hdl is not None:
                        src_file_hdl.close()
    return (patch_cnt, match_cnt)


def get_patch_files(patch_path):
    """Collect the ``.patch`` files designated by *patch_path*.

    * A regular file yields a one-element list holding its absolute path, or
      an empty list when its name does not end with ``.patch``.
    * A directory is walked recursively; every file whose name ends with
      ``.patch`` is returned as ``<walked directory>/<file name>``.
    * Anything else yields an empty list.
    """
    suffix = ".patch"
    if os.path.isfile(patch_path):
        if not patch_path.endswith(suffix):
            return []
        return [os.path.abspath(patch_path)]
    if not os.path.isdir(patch_path):
        return []
    found = []
    for dir_path, _subdirs, names in os.walk(patch_path):
        found.extend(dir_path + "/" + name for name in names if name.endswith(suffix))
    return found


def do_patch_check(patch_files, args):
    """Run the ``check`` subcommand and print a CSV report on stdout.

    Every patch is checked against ``args.source`` with ``check_patch_file``;
    afterwards a ``patch,lines,match,percent`` header is printed followed by
    one row per patch, in the order of *patch_files*.  ``percent`` is ``0``
    when the patch could not be read (``lines`` is ``-1``) or has no counted
    lines.

    Args:
        patch_files: Paths of the patch files to check.
        args: Parsed command line; only ``args.source`` is read.

    Raises:
        SystemExit: With status 1 when ``args.source`` is missing or empty.
    """
    src_dir = args.source
    if not src_dir:
        print("error: the following arguments are required: -s/--source")
        sys.exit(1)
    # Rows are kept in a list rather than a dict keyed by base name: patches
    # gathered from different sub-directories may share a file name, and a
    # dict would silently keep only the last one.
    rows = [(file.split('/')[-1], check_patch_file(file, src_dir)) for file in patch_files]
    print("patch,lines,match,percent")
    for patch, (lines, matched) in rows:
        if lines not in (-1, 0):
            percent = "{:.1%}".format(matched / lines)
        else:
            percent = "0"
        print("{},{},{},{}".format(patch, lines, matched, percent))


def _read_modified_files(patch):
    """Return the set of file names touched by the patch file *patch*.

    A file is recorded for every ``--- old`` line immediately followed by a
    ``+++ new`` line.  The old name is used when only the new side is
    ``/dev/null`` (file deletion), otherwise the new name; the first path
    component is dropped when the name contains a ``/``.
    """
    modified = set()
    with open_patch(patch) as file_hdl:
        line = file_hdl.readline()
        while line:
            if not (line.startswith("--- ") and len(line.split()) > 1):
                line = file_hdl.readline()
                continue
            old_file = line.split()[1]
            line = file_hdl.readline()
            if not line:
                break
            if not (line.startswith("+++ ") and len(line.split()) > 1):
                # Re-examine this line: it may be the start of a real header.
                continue
            new_file = line.split()[1]
            if old_file != "/dev/null" and new_file == "/dev/null":
                touched = old_file  # delete file
            else:
                touched = new_file  # add or modify file
            if '/' in touched:
                touched = '/'.join(touched.split('/')[1:])
            if touched:
                modified.add(touched)
            line = file_hdl.readline()
    return modified


def _group_patches(patch_mfile, mfile_patch):
    """Split patches into groups connected through shared modified files.

    Args:
        patch_mfile: Maps each patch to the set of files it touches.
        mfile_patch: Maps each file to the set of patches touching it.

    Returns:
        A list of groups (lists of patches), ordered by the position of each
        group's first patch in *patch_mfile*.  The order inside a group is
        unspecified.
    """
    patches = list(patch_mfile)
    groups = []
    while patches:
        seed = patches[0]
        group = {seed}
        frontier = set(patch_mfile[seed])
        seen_files = set(frontier)
        # Alternate files -> patches -> files until no new file is reached.
        while frontier:
            reached = set()
            for file in frontier:
                group |= mfile_patch[file]
                for patch in mfile_patch[file]:
                    reached |= patch_mfile[patch]
            frontier = reached - seen_files
            seen_files |= frontier
        patches = [patch for patch in patches if patch not in group]
        groups.append(list(group))
    return groups


def do_patch_group(patch_files, args):
    """Run the ``group`` subcommand: print patches grouped by the files they modify.

    Two patches belong to the same group when they touch a common file,
    directly or through a chain of other patches.  For each group a
    ``The <n> group: `` heading is printed, then the base names of its
    patches in reverse-sorted path order, then an empty line.

    Args:
        patch_files: Paths of the patch files to group.
        args: Parsed command line (not used by this subcommand).

    Raises:
        OSError: When a patch file cannot be opened.
    """
    mfile_patch = {}
    patch_mfile = {}
    for patch in patch_files:
        modified = _read_modified_files(patch)
        # Patches touching no file still get an entry and form their own group.
        patch_mfile.setdefault(patch, set()).update(modified)
        for source_file in modified:
            mfile_patch.setdefault(source_file, set()).add(patch)

    for i, group in enumerate(_group_patches(patch_mfile, mfile_patch), start=1):
        print("The " + str(i) + " group: ")
        for patch in sorted(group, reverse=True):
            print(patch.split('/')[-1])
        print("")


def do_patch_info(patch_files, args):
    """Run the ``info`` subcommand: list the files each patch touches.

    For every patch a ``<n>) <patch base name>`` heading is printed.  Then,
    for every ``--- old`` line immediately followed by a ``+++ new`` line, one
    row ``    <flag> <file> (-<deleted>,+<added>)`` is printed, where the flag
    is ``+`` for an added file (old side is ``/dev/null``), ``-`` for a
    deleted file (new side is ``/dev/null``) and ``*`` for a modified file.
    Lines are counted up to the next line starting with ``"--- "``; removed
    lines that start with ``--`` are not counted.

    Args:
        patch_files: Paths of the patch files to describe.
        args: Parsed command line (not used by this subcommand).

    Raises:
        OSError: When a patch file cannot be opened.
    """
    def is_header(text, marker):
        # A header line is the marker followed by at least one more token.
        return text.startswith(marker) and len(text.split()) > 1

    def strip_first_component(name):
        # "a/src/x.c" -> "src/x.c"; names without '/' are left untouched.
        if '/' in name:
            return '/'.join(name.split('/')[1:])
        return name

    for i, patch in enumerate(patch_files, start=1):
        with open_patch(patch) as file_hdl:
            print("{}) {}".format(i, patch.split('/')[-1]))
            line = file_hdl.readline()
            while line:
                if not is_header(line, "--- "):
                    # Always advance here.  A line such as "--- " with no file
                    # name is not a header and must not be re-read forever.
                    line = file_hdl.readline()
                    continue
                old_file = line.split()[1]
                line = file_hdl.readline()
                if not line:
                    break
                if not is_header(line, "+++ "):
                    # Re-examine this line: it may be the start of a real header.
                    continue
                new_file = line.split()[1]
                del_lines = add_lines = 0
                line = file_hdl.readline()
                while line and not line.startswith("--- "):
                    if line.startswith("-") and not line.startswith("--"):
                        del_lines += 1
                    elif line.startswith("+"):
                        add_lines += 1
                    line = file_hdl.readline()
                if old_file == "/dev/null":  # add file
                    source_file = strip_first_component(new_file)
                    flag = "+"
                elif new_file == "/dev/null":  # delete file
                    source_file = strip_first_component(old_file)
                    flag = "-"
                else:  # modify file
                    source_file = strip_first_component(new_file)
                    flag = "*"
                print("    {} {} (-{},+{})".format(flag, source_file, del_lines, add_lines))
                # `line` is now empty (EOF) or starts with "--- "; the outer
                # loop examines it next without reading past it.


def do_main(args):
    """Collect the patch files named by ``args.patch`` and run ``args.command``.

    Prints ``no such patch files`` and exits with status 1 when no patch file
    is found.  Otherwise the patch list is sorted and handed, together with
    *args*, to the handler of the selected subcommand.
    """
    patch_files = get_patch_files(args.patch)
    if not patch_files:
        print("no such patch files")
        sys.exit(1)
    ordered = sorted(patch_files)
    handlers = {
        "check": do_patch_check,
        "group": do_patch_group,
        "info": do_patch_info,
    }
    handler = handlers.get(args.command)
    if handler is not None:
        handler(ordered, args)


def main():
    """Script entry point: parse the command line and run the chosen subcommand."""
    do_main(parse_command_line())


# Run the tool only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
